# 有条件的图像生成模块

import torch.nn as nn


class Generator(nn.Module):
    """Transposed-convolution image generator (DCGAN-style layer stack).

    The network is a stack of five ``ConvTranspose2d`` layers. The first
    (kernel 4, stride 1, padding 0) expands a 1x1 spatial input to 4x4; each
    of the remaining four (kernel 4, stride 2, padding 1) doubles the spatial
    size, so an input of shape ``(N, nx, 1, 1)`` yields an output of shape
    ``(N, nc, 64, 64)``. The final ``Tanh`` bounds outputs to ``[-1, 1]``.

    Args:
        nx: Number of input channels, i.e. the length of the input vector.
        nc: Number of channels in the generated image.
        ngf: Base number of generator feature maps; hidden layers use
            ``ngf * 8``, ``ngf * 4``, ``ngf * 2`` and ``ngf`` channels.

    The defaults reproduce the previously hard-coded configuration
    (``nx=100``, ``nc=3``, ``ngf=64``), so ``Generator()`` is unchanged.
    """

    def __init__(self, nx: int = 100, nc: int = 3, ngf: int = 64):
        super().__init__()

        self.nx = nx      # input vector length (channels of the input tensor)
        self.nc = nc      # number of image channels produced
        self.ngf = ngf    # base feature-map width of the generator

        # Layer order is kept exactly as before so that state_dict keys
        # ("core.0.weight", ..., "core.12.weight") stay compatible with
        # previously saved checkpoints.
        # Convolution biases are disabled on every layer; the four hidden
        # layers are each followed by BatchNorm2d and an in-place ReLU.
        self.core = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(in_channels=nx, out_channels=ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # 4x4 -> 8x8
            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # 32x32 -> 64x64, squashed to [-1, 1] by Tanh
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack and return the result.

        Args:
            inputs: Tensor whose channel dimension equals ``self.nx``.
                Presumably shaped ``(N, nx, 1, 1)`` -- confirm against the
                caller; larger spatial inputs produce larger outputs.

        Returns:
            Tensor with ``self.nc`` channels and values in ``[-1, 1]``.
        """
        return self.core(inputs)
